import argparse
import torch
from tqdm import tqdm
from pathlib import Path

from torch.utils.data.dataloader import DataLoader
from allennlp.data.dataset_readers.dataset_utils.span_utils import bio_tags_to_spans

from parse_config import ConfigParser
import model.pick as pick_arch_module
from data_utils.pick_dataset_predict import PICKDataset
from data_utils.pick_dataset_predict import BatchCollateFn
from utils.util import iob_index_to_str, text_index_to_str
import numpy as np
import random
from .pipeline import input_pipeline
def setup_seed(seed):
    """Seed every RNG backend (numpy, random, torch CPU and CUDA) for reproducible runs.

    Args:
        seed: integer seed applied to all backends.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)  # fixed typo: was torch.mannual_seed (AttributeError at runtime)
    torch.cuda.manual_seed_all(seed)
    # Determinism requires the cudnn backend flags (the original wrote
    # torch.backends.cuda.deterministic, which is a silent no-op), and
    # benchmark must be OFF or cuDNN may pick non-deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(ocr_output, pick_model, device):
    """Run PICK entity extraction over one OCR result.

    Args:
        ocr_output: raw OCR output consumed by ``input_pipeline`` to build the
            boxes/transcripts data and the input image.  # assumes pipeline contract — TODO confirm
        pick_model: loaded PICK model; switched to eval mode here.
        device: torch device the model lives on; input tensors are moved to it.

    Returns:
        list[list[dict]]: one entry per processed image; each entry is a list of
        ``{'entity_name': str, 'text': str}`` dicts ordered by span start.
    """
    input_data, input_image = input_pipeline(ocr_output)

    pick_model.eval()
    setup_seed(518)

    # setup dataset and data_loader instances
    test_dataset = PICKDataset(boxes_and_transcripts_data=input_data,
                               images=input_image,
                               resized_image_size=(480, 960),
                               ignore_error=False)

    test_data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False,
                                  num_workers=1, collate_fn=BatchCollateFn(training=False))

    # Accumulate across batches.  The original re-initialized this inside the
    # loop (losing all but the last batch) and left it unbound when the loader
    # yielded no batches, which raised NameError at the return below.
    final_result = []
    with torch.no_grad():
        for step_idx, input_data_item in tqdm(enumerate(test_data_loader)):
            # Move every tensor in the batch onto the model's device.
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(device)

            output = pick_model(**input_data_item)
            logits = output['logits']  # (B, N*T, out_dim)
            new_mask = output['new_mask']
            text_segments = input_data_item['text_segments']  # (B, num_boxes, T)
            mask = input_data_item['mask']
            # List[(List[int], torch.Tensor)] — best tag sequence per sample.
            best_paths = pick_model.decoder.crf_layer.viterbi_tags(logits, mask=new_mask, logits_batch_first=True)
            predicted_tags = [path for path, score in best_paths]

            # convert iob index to iob string
            decoded_tags_list = iob_index_to_str(predicted_tags)
            # union text as a sequence and convert index to string
            decoded_texts_list = text_index_to_str(text_segments, mask)

            for decoded_tags, decoded_texts in zip(decoded_tags_list, decoded_texts_list):
                # List[ Tuple[str, Tuple[int, int]] ], sorted by span start
                spans = bio_tags_to_spans(decoded_tags, [])
                spans = sorted(spans, key=lambda x: x[1][0])

                entities = []  # exists one to many case
                for entity_name, range_tuple in spans:
                    entity = dict(entity_name=entity_name,
                                  text=''.join(decoded_texts[range_tuple[0]:range_tuple[1] + 1]))
                    entities.append(entity)
                final_result.append(entities)
    print(final_result)
    return final_result




if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch PICK Testing')
    parser.add_argument('-ckpt', '--checkpoint', default=None, type=str,
                        help='path to load checkpoint (default: None)')
    parser.add_argument('--bt', '--boxes_transcripts', default=None, type=str,
                        help='ocr results folder including boxes and transcripts (default: None)')
    parser.add_argument('--impt', '--images_path', default=None, type=str,
                        help='images folder path (default: None)')
    parser.add_argument('-predict_output', '--output_folder', default='predict_results', type=str,
                        help='predict_output folder (default: predict_results)')
    parser.add_argument('-g', '--gpu', default=-1, type=int,
                        help='GPU id to use. (default: -1, cpu)')
    parser.add_argument('--bs', '--batch_size', default=1, type=int,
                        help='batch size (default: 1)')
    args = parser.parse_args()

    # NOTE(review): the original called main(args), but main() takes
    # (ocr_output, pick_model, device) and would raise TypeError immediately.
    # Reconstructed the standard PICK checkpoint-loading flow below — confirm
    # the checkpoint keys ('config', 'state_dict') against the project's
    # training code.  Also note the relative import `from .pipeline import
    # input_pipeline` at the top of this file will fail when the file is run
    # directly as a script rather than as a package module.
    device = torch.device('cuda:{}'.format(args.gpu)
                          if args.gpu != -1 and torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(args.checkpoint, map_location=device)
    config = checkpoint['config']
    state_dict = checkpoint['state_dict']
    pick_model = config.init_obj('model_arch', pick_arch_module)
    pick_model.load_state_dict(state_dict)
    pick_model = pick_model.to(device)

    main(args.bt, pick_model, device)
